package com.ml4ai.demo;

import com.ml4ai.nn.BaseForwardNetwork;
import com.ml4ai.nn.Plot;
import com.ml4ai.nn.Recurrent;
import com.ml4ai.nn.core.Operation;
import com.ml4ai.nn.core.Toolkit;
import com.ml4ai.nn.core.Variable;
import com.ml4ai.nn.core.optimizers.Moment;
import com.ml4ai.nn.core.optimizers.NNOptimizer;
import org.jfree.data.category.DefaultCategoryDataset;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;

import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;

import javax.swing.SwingUtilities;

/**
 * Demo that trains {@link RNN} to turn a cosine sequence into the matching sine sequence
 * and feeds two line charts while it trains: one with the loss per iteration and one with
 * the most recent prediction.
 */
public class RNNDemo {

    /**
     * Runs the training loop.
     *
     * <p>The input is {@code step} evenly spaced points on {@code [0, 2*PI]} reshaped to
     * {@code (1, step, 1)}; {@code cos} of those points is the network input and {@code sin}
     * is the target. The loss is the mean of the squared difference between prediction and
     * target.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        int step = 10;
        Variable seed = new Variable(Nd4j.linspace(0, 2 * Math.PI, step).reshape(1, step, 1));
        Variable x = seed.cos();
        Variable y = seed.sin();
        RNN rnn = new RNN();
        // NOTE(review): 1e-3 / 0.95 are presumably learning rate and momentum - confirm against Moment.
        NNOptimizer optimizer = new Moment(rnn.getParameters(), 1e-3, 0.95);
        Toolkit tool = new Toolkit();
        DefaultCategoryDataset dcd = Plot.createLineChart("train chart", "train");
        DefaultCategoryDataset cos = Plot.createLineChart("predict chart", "predict");
        for (int i = 0; i < 10000; i++) {
            Variable[] predicts = rnn.forward(x);
            Variable predict = predicts[0];
            Variable loss = predict.sub(y).square().mean();
            tool.grad2zero(loss);
            tool.backward(loss);
            optimizer.update();
            dcd.addValue(loss.data.scalar, "loss", String.valueOf(i));

            // Copy the predicted values out on the training thread so the chart update
            // never touches the tensor from another thread.
            final double[] snapshot = new double[step];
            for (int j = 0; j < step; j++) {
                snapshot[j] = predict.data.tensor.getScalar(0, j, 0).data().getDouble(0);
            }

            // One queued task per iteration instead of one new thread per iteration:
            // the event queue runs tasks one at a time in submission order, so a
            // clear() can no longer interleave with another iteration's addValue()
            // calls and an older prediction can no longer overwrite a newer one.
            SwingUtilities.invokeLater(() -> {
                cos.clear();
                for (int j = 0; j < snapshot.length; j++) {
                    cos.addValue(snapshot[j], "predict", String.valueOf(j));
                }
            });
        }
    }

}

/**
 * Three stacked tanh recurrent cells followed by two bias-free matrix multiplications
 * (32 -> 8 -> 1) that are applied to every time step of every batch entry.
 */
class RNN extends BaseForwardNetwork {

    private static final int INPUT_SIZE = 1;
    private static final int HIDDEN_SIZE = 32;
    private static final int PROJECTION_SIZE = 8;

    private final Recurrent.RNNCell a;
    private final Recurrent.RNNCell b;
    private final Recurrent.RNNCell c;
    private final Variable fc;
    private final Variable ou;

    /** Creates the three cells and randomly initialises both projection matrices. */
    RNN() {
        this.a = new Recurrent.RNNCell(INPUT_SIZE, HIDDEN_SIZE, Operation.TanH);
        this.b = new Recurrent.RNNCell(HIDDEN_SIZE, HIDDEN_SIZE, Operation.TanH);
        this.c = new Recurrent.RNNCell(HIDDEN_SIZE, HIDDEN_SIZE, Operation.TanH);
        this.fc = new Variable(Nd4j.randn(new int[]{HIDDEN_SIZE, PROJECTION_SIZE}));
        this.ou = new Variable(Nd4j.randn(PROJECTION_SIZE, 1));
    }

    /**
     * Collects every trainable variable of the network.
     *
     * @return the parameters of cells {@code a}, {@code b} and {@code c} (in that order),
     *         followed by the two projection matrices
     */
    @Override
    public Variable[] getParameters() {
        List<Variable> collected = new LinkedList<>();
        Recurrent.RNNCell[] cells = {a, b, c};
        for (Recurrent.RNNCell cell : cells) {
            collected.addAll(Arrays.asList(cell.getParameters()));
        }
        collected.add(fc);
        collected.add(ou);
        return collected.toArray(new Variable[0]);
    }

    /**
     * Runs the input through the three cells and projects each time step to one value.
     *
     * @param x network input, handed unchanged to the first cell
     * @return a one-element array holding the projected output, reshaped to
     *         {@code (1, time_step, 1)} per batch entry and joined along dimension 0;
     *         the element is {@code null} when the batch dimension is empty
     */
    public Variable[] forward(Variable... x) {
        // Layout noted by the original author: [batch_size, time_step, hidden_size]
        Variable[] stacked = c.forward(b.forward(a.forward(x)));
        Variable hidden = stacked[0];
        int batchCount = hidden.data.tensor.shape()[0];
        int stepCount = hidden.data.tensor.shape()[1];

        // Phase 1: project every batch entry on its own.
        Variable[] perSample = new Variable[batchCount];
        for (int n = 0; n < batchCount; n++) {
            Variable sample = hidden.separate(NDArrayIndex.point(n), NDArrayIndex.all(), NDArrayIndex.all());
            Variable projected = sample.matMul(fc).matMul(ou);
            perSample[n] = projected.refactor(1, stepCount, 1);
        }

        // Phase 2: join the per-entry results back together along dimension 0.
        Variable merged = perSample.length > 0 ? perSample[0] : null;
        for (int n = 1; n < perSample.length; n++) {
            merged = merged.connect(perSample[n], 0);
        }
        return new Variable[]{merged};
    }

}